
<!DOCTYPE html>

<html lang="zh">
  <head>
    <meta charset="utf-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="generator" content="Docutils 0.17.1: http://docutils.sourceforge.net/" />

    <title>3.4 模型构建 &#8212; 深入浅出PyTorch</title>
    
  <!-- Loaded before other Sphinx assets -->
  <link href="../_static/styles/theme.css?digest=1999514e3f237ded88cf" rel="stylesheet">
<link href="../_static/styles/pydata-sphinx-theme.css?digest=1999514e3f237ded88cf" rel="stylesheet">

    
  <link rel="stylesheet"
    href="../_static/vendor/fontawesome/5.13.0/css/all.min.css">
  <link rel="preload" as="font" type="font/woff2" crossorigin
    href="../_static/vendor/fontawesome/5.13.0/webfonts/fa-solid-900.woff2">
  <link rel="preload" as="font" type="font/woff2" crossorigin
    href="../_static/vendor/fontawesome/5.13.0/webfonts/fa-brands-400.woff2">

    <link rel="stylesheet" type="text/css" href="../_static/pygments.css" />
    <link rel="stylesheet" href="../_static/styles/sphinx-book-theme.css?digest=62ba249389abaaa9ffc34bf36a076bdc1d65ee18" type="text/css" />
    <link rel="stylesheet" type="text/css" href="../_static/togglebutton.css" />
    <link rel="stylesheet" type="text/css" href="../_static/mystnb.css" />
    <link rel="stylesheet" type="text/css" href="../_static/plot_directive.css" />
    
  <!-- Pre-loaded scripts that we'll load fully later -->
  <link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=1999514e3f237ded88cf">

    <script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script>
    <script src="../_static/jquery.js"></script>
    <script src="../_static/underscore.js"></script>
    <script src="../_static/doctools.js"></script>
    <script>let toggleHintShow = 'Click to show';</script>
    <script>let toggleHintHide = 'Click to hide';</script>
    <script>let toggleOpenOnPrint = 'true';</script>
    <script src="../_static/togglebutton.js"></script>
    <script src="../_static/scripts/sphinx-book-theme.js?digest=f31d14ad54b65d19161ba51d4ffff3a77ae00456"></script>
    <script>var togglebuttonSelector = '.toggle, .admonition.dropdown, .tag_hide_input div.cell_input, .tag_hide-input div.cell_input, .tag_hide_output div.cell_output, .tag_hide-output div.cell_output, .tag_hide_cell.cell, .tag_hide-cell.cell';</script>
    <script>window.MathJax = {"options": {"processHtmlClass": "tex2jax_process|mathjax_process|math|output_area"}}</script>
    <script defer="defer" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
    <link rel="index" title="索引" href="../genindex.html" />
    <link rel="search" title="搜索" href="../search.html" />
    <link rel="next" title="3.5 模型初始化" href="3.5%20%E6%A8%A1%E5%9E%8B%E5%88%9D%E5%A7%8B%E5%8C%96.html" />
    <link rel="prev" title="3.3 数据读入" href="3.3%20%E6%95%B0%E6%8D%AE%E8%AF%BB%E5%85%A5.html" />
    <meta name="viewport" content="width=device-width, initial-scale=1" />
    <meta name="docsearch:language" content="zh">
    

    <!-- Google Analytics -->
    
  </head>
  <body data-spy="scroll" data-target="#bd-toc-nav" data-offset="60">
<!-- Checkboxes to toggle the left sidebar -->
<input type="checkbox" class="sidebar-toggle" name="__navigation" id="__navigation" aria-label="Toggle navigation sidebar">
<label class="overlay overlay-navbar" for="__navigation">
    <div class="visually-hidden">Toggle navigation sidebar</div>
</label>
<!-- Checkboxes to toggle the in-page toc -->
<input type="checkbox" class="sidebar-toggle" name="__page-toc" id="__page-toc" aria-label="Toggle in-page Table of Contents">
<label class="overlay overlay-pagetoc" for="__page-toc">
    <div class="visually-hidden">Toggle in-page Table of Contents</div>
</label>
<!-- Headers at the top -->
<div class="announcement header-item noprint"></div>
<div class="header header-item noprint"></div>

    
    <div class="container-fluid" id="banner"></div>

    

    <div class="container-xl">
      <div class="row">
          
<!-- Sidebar -->
<div class="bd-sidebar noprint" id="site-navigation">
    <div class="bd-sidebar__content">
        <div class="bd-sidebar__top"><div class="navbar-brand-box">
    <a class="navbar-brand text-wrap" href="../index.html">
      
      
      
      <h1 class="site-logo" id="site-title">深入浅出PyTorch</h1>
      
    </a>
</div><form class="bd-search d-flex align-items-center" action="../search.html" method="get">
  <i class="icon fas fa-search"></i>
  <input type="search" class="form-control" name="q" id="search-input" placeholder="Search the docs ..." aria-label="Search the docs ..." autocomplete="off" >
</form><nav class="bd-links" id="bd-docs-nav" aria-label="Main">
    <div class="bd-toc-item active">
        <p aria-level="2" class="caption" role="heading">
 <span class="caption-text">
  目录
 </span>
</p>
<ul class="current nav bd-sidenav">
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E4%B8%80%E7%AB%A0/index.html">
   第一章：PyTorch的简介和安装
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-1" name="toctree-checkbox-1" type="checkbox"/>
  <label for="toctree-checkbox-1">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%80%E7%AB%A0/1.1%20PyTorch%E7%AE%80%E4%BB%8B.html">
     1.1 PyTorch简介
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%80%E7%AB%A0/1.2%20PyTorch%E7%9A%84%E5%AE%89%E8%A3%85.html">
     1.2 PyTorch的安装
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%80%E7%AB%A0/1.3%20PyTorch%E7%9B%B8%E5%85%B3%E8%B5%84%E6%BA%90.html">
     1.3 PyTorch相关资源
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E4%BA%8C%E7%AB%A0/index.html">
   第二章：PyTorch基础知识
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-2" name="toctree-checkbox-2" type="checkbox"/>
  <label for="toctree-checkbox-2">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%8C%E7%AB%A0/2.1%20%E5%BC%A0%E9%87%8F.html">
     2.1 张量
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%8C%E7%AB%A0/2.2%20%E8%87%AA%E5%8A%A8%E6%B1%82%E5%AF%BC.html">
     2.2 自动求导
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%8C%E7%AB%A0/2.3%20%E5%B9%B6%E8%A1%8C%E8%AE%A1%E7%AE%97%E7%AE%80%E4%BB%8B.html">
     2.3 并行计算简介
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 current active has-children">
  <a class="reference internal" href="index.html">
   第三章：PyTorch的主要组成模块
  </a>
  <input checked="" class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
  <label for="toctree-checkbox-3">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul class="current">
   <li class="toctree-l2">
    <a class="reference internal" href="3.1%20%E6%80%9D%E8%80%83%EF%BC%9A%E5%AE%8C%E6%88%90%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E7%9A%84%E5%BF%85%E8%A6%81%E9%83%A8%E5%88%86.html">
     3.1 思考：完成深度学习的必要部分
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="3.2%20%E5%9F%BA%E6%9C%AC%E9%85%8D%E7%BD%AE.html">
     3.2 基本配置
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="3.3%20%E6%95%B0%E6%8D%AE%E8%AF%BB%E5%85%A5.html">
     3.3 数据读入
    </a>
   </li>
   <li class="toctree-l2 current active">
    <a class="current reference internal" href="#">
     3.4 模型构建
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="3.5%20%E6%A8%A1%E5%9E%8B%E5%88%9D%E5%A7%8B%E5%8C%96.html">
     3.5 模型初始化
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="3.6%20%E6%8D%9F%E5%A4%B1%E5%87%BD%E6%95%B0.html">
     3.6 损失函数
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="3.7%20%E8%AE%AD%E7%BB%83%E4%B8%8E%E8%AF%84%E4%BC%B0.html">
     3.7 训练和评估
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="3.8%20%E5%8F%AF%E8%A7%86%E5%8C%96.html">
     3.8 可视化
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="3.9%20%E4%BC%98%E5%8C%96%E5%99%A8.html">
     3.9 Pytorch优化器
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E5%9B%9B%E7%AB%A0/index.html">
   第四章：PyTorch基础实战
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-4" name="toctree-checkbox-4" type="checkbox"/>
  <label for="toctree-checkbox-4">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%9B%9B%E7%AB%A0/%E5%9F%BA%E7%A1%80%E5%AE%9E%E6%88%98%E2%80%94%E2%80%94FashionMNIST%E6%97%B6%E8%A3%85%E5%88%86%E7%B1%BB.html">
     基础实战——FashionMNIST时装分类
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E4%BA%94%E7%AB%A0/index.html">
   第五章：PyTorch模型定义
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-5" name="toctree-checkbox-5" type="checkbox"/>
  <label for="toctree-checkbox-5">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%94%E7%AB%A0/5.1%20PyTorch%E6%A8%A1%E5%9E%8B%E5%AE%9A%E4%B9%89%E7%9A%84%E6%96%B9%E5%BC%8F.html">
     5.1 PyTorch模型定义的方式
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%94%E7%AB%A0/5.2%20%E5%88%A9%E7%94%A8%E6%A8%A1%E5%9E%8B%E5%9D%97%E5%BF%AB%E9%80%9F%E6%90%AD%E5%BB%BA%E5%A4%8D%E6%9D%82%E7%BD%91%E7%BB%9C.html">
     5.2 利用模型块快速搭建复杂网络
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%94%E7%AB%A0/5.3%20PyTorch%E4%BF%AE%E6%94%B9%E6%A8%A1%E5%9E%8B.html">
     5.3 PyTorch修改模型
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%94%E7%AB%A0/5.4%20PyTorh%E6%A8%A1%E5%9E%8B%E4%BF%9D%E5%AD%98%E4%B8%8E%E8%AF%BB%E5%8F%96.html">
     5.4 PyTorch模型保存与读取
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/index.html">
   第六章：PyTorch进阶训练技巧
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
  <label for="toctree-checkbox-6">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.1%20%E8%87%AA%E5%AE%9A%E4%B9%89%E6%8D%9F%E5%A4%B1%E5%87%BD%E6%95%B0.html">
     6.1 自定义损失函数
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.2%20%E5%8A%A8%E6%80%81%E8%B0%83%E6%95%B4%E5%AD%A6%E4%B9%A0%E7%8E%87.html">
     6.2 动态调整学习率
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.3%20%E6%A8%A1%E5%9E%8B%E5%BE%AE%E8%B0%83-torchvision.html">
     6.3 模型微调-torchvision
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.3%20%E6%A8%A1%E5%9E%8B%E5%BE%AE%E8%B0%83-timm.html">
     6.3 模型微调 - timm
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.4%20%E5%8D%8A%E7%B2%BE%E5%BA%A6%E8%AE%AD%E7%BB%83.html">
     6.4 半精度训练
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.5%20%E6%95%B0%E6%8D%AE%E5%A2%9E%E5%BC%BA-imgaug.html">
     6.5 数据增强-imgaug
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.6%20%E4%BD%BF%E7%94%A8argparse%E8%BF%9B%E8%A1%8C%E8%B0%83%E5%8F%82.html">
     6.6 使用argparse进行调参
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/PyTorch%E6%A8%A1%E5%9E%8B%E5%AE%9A%E4%B9%89%E4%B8%8E%E8%BF%9B%E9%98%B6%E8%AE%AD%E7%BB%83%E6%8A%80%E5%B7%A7.html">
     PyTorch模型定义与进阶训练技巧
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E4%B8%83%E7%AB%A0/index.html">
   第七章：PyTorch可视化
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-7" name="toctree-checkbox-7" type="checkbox"/>
  <label for="toctree-checkbox-7">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%83%E7%AB%A0/7.1%20%E5%8F%AF%E8%A7%86%E5%8C%96%E7%BD%91%E7%BB%9C%E7%BB%93%E6%9E%84.html">
     7.1 可视化网络结构
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%83%E7%AB%A0/7.2%20CNN%E5%8D%B7%E7%A7%AF%E5%B1%82%E5%8F%AF%E8%A7%86%E5%8C%96.html">
     7.2 CNN可视化
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%83%E7%AB%A0/7.3%20%E4%BD%BF%E7%94%A8TensorBoard%E5%8F%AF%E8%A7%86%E5%8C%96%E8%AE%AD%E7%BB%83%E8%BF%87%E7%A8%8B.html">
     7.3 使用TensorBoard可视化训练过程
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/index.html">
   第八章：PyTorch生态简介
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-8" name="toctree-checkbox-8" type="checkbox"/>
  <label for="toctree-checkbox-8">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/8.1%20%E6%9C%AC%E7%AB%A0%E7%AE%80%E4%BB%8B.html">
     8.1 本章简介
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/8.2%20%E5%9B%BE%E5%83%8F%20-%20torchvision.html">
     8.2 torchvision
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/8.3%20%E8%A7%86%E9%A2%91%20-%20PyTorchVideo.html">
     8.3 PyTorchVideo简介
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/8.4%20%E6%96%87%E6%9C%AC%20-%20torchtext.html">
     8.4 torchtext简介
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/transforms%E5%AE%9E%E6%93%8D.html">
     transforms实战
    </a>
   </li>
  </ul>
 </li>
</ul>

    </div>
</nav></div>
        <div class="bd-sidebar__bottom">
             <!-- To handle the deprecated key -->
            
            <div class="navbar_extra_footer">
            Theme by the <a href="https://ebp.jupyterbook.org">Executable Book Project</a>
            </div>
            
        </div>
    </div>
    <div id="rtd-footer-container"></div>
</div>


          


          
<!-- A tiny helper pixel to detect if we've scrolled -->
<div class="sbt-scroll-pixel-helper"></div>
<!-- Main content -->
<div class="col py-0 content-container">
    
    <div class="header-article row sticky-top noprint">
        



<div class="col py-1 d-flex header-article-main">
    <div class="header-article__left">
        
        <label for="__navigation"
  class="headerbtn"
  data-toggle="tooltip"
data-placement="right"
title="Toggle navigation"
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-bars"></i>
  </span>

</label>

        
    </div>
    <div class="header-article__right">
<button onclick="toggleFullScreen()"
  class="headerbtn"
  data-toggle="tooltip"
data-placement="bottom"
title="Fullscreen mode"
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-expand"></i>
  </span>

</button>

<div class="menu-dropdown menu-dropdown-repository-buttons">
  <button class="headerbtn menu-dropdown__trigger"
      aria-label="Source repositories">
      <i class="fab fa-github"></i>
  </button>
  <div class="menu-dropdown__content">
    <ul>
      <li>
        <a href="https://github.com/datawhalechina/thorough-pytorch"
   class="headerbtn"
   data-toggle="tooltip"
data-placement="left"
title="Source repository"
>
  

<span class="headerbtn__icon-container">
  <i class="fab fa-github"></i>
  </span>
<span class="headerbtn__text-container">repository</span>
</a>

      </li>
      
      <li>
        <a href="https://github.com/datawhalechina/thorough-pytorch/issues/new?title=Issue%20on%20page%20%2F%E7%AC%AC%E4%B8%89%E7%AB%A0%2F3.4%20%E6%A8%A1%E5%9E%8B%E6%9E%84%E5%BB%BA.html&amp;body=Your%20issue%20content%20here."
   class="headerbtn"
   data-toggle="tooltip"
data-placement="left"
title="Open an issue"
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-lightbulb"></i>
  </span>
<span class="headerbtn__text-container">open issue</span>
</a>

      </li>
      
      <li>
        <a href="https://github.com/datawhalechina/thorough-pytorch/edit/master/%E7%AC%AC%E4%B8%89%E7%AB%A0/3.4%20%E6%A8%A1%E5%9E%8B%E6%9E%84%E5%BB%BA.md"
   class="headerbtn"
   data-toggle="tooltip"
data-placement="left"
title="Edit this page"
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-pencil-alt"></i>
  </span>
<span class="headerbtn__text-container">suggest edit</span>
</a>

      </li>
      
    </ul>
  </div>
</div>

<div class="menu-dropdown menu-dropdown-download-buttons">
  <button class="headerbtn menu-dropdown__trigger"
      aria-label="Download this page">
      <i class="fas fa-download"></i>
  </button>
  <div class="menu-dropdown__content">
    <ul>
      <li>
        <a href="../_sources/%E7%AC%AC%E4%B8%89%E7%AB%A0/3.4%20%E6%A8%A1%E5%9E%8B%E6%9E%84%E5%BB%BA.md.txt"
   class="headerbtn"
   data-toggle="tooltip"
data-placement="left"
title="Download source file"
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-file"></i>
  </span>
<span class="headerbtn__text-container">.md</span>
</a>

      </li>
      
      <li>
        
<button onclick="printPdf(this)"
  class="headerbtn"
  data-toggle="tooltip"
data-placement="left"
title="Print to PDF"
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-file-pdf"></i>
  </span>
<span class="headerbtn__text-container">.pdf</span>
</button>

      </li>
      
    </ul>
  </div>
</div>
<label for="__page-toc"
  class="headerbtn headerbtn-page-toc"
  
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-list"></i>
  </span>

</label>

    </div>
</div>

<!-- Table of contents -->
<div class="col-md-3 bd-toc show noprint">
    <div class="tocsection onthispage pt-5 pb-3">
        <i class="fas fa-list"></i> Contents
    </div>
    <nav id="bd-toc-nav" aria-label="Page">
        <ul class="visible nav section-nav flex-column">
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id2">
   3.4.1 神经网络的构造
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id3">
   3.4.2 神经网络中常见的层
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id4">
   3.4.3 模型示例
  </a>
 </li>
</ul>

    </nav>
</div>
    </div>
    <div class="article row">
        <div class="col pl-md-3 pl-lg-5 content-container">
            <!-- Table of contents that is only displayed when printing the page -->
            <div id="jb-print-docs-body" class="onlyprint">
                <h1>3.4 模型构建</h1>
                <!-- Table of contents -->
                <div id="print-main-content">
                    <div id="jb-print-toc">
                        
                        <div>
                            <h2> Contents </h2>
                        </div>
                        <nav aria-label="Page">
                            <ul class="visible nav section-nav flex-column">
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id2">
   3.4.1 神经网络的构造
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id3">
   3.4.2 神经网络中常见的层
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id4">
   3.4.3 模型示例
  </a>
 </li>
</ul>

                        </nav>
                    </div>
                </div>
            </div>
            <main id="main-content" role="main">
                
              <div>
                
  <section class="tex2jax_ignore mathjax_ignore" id="id1">
<h1>3.4 模型构建<a class="headerlink" href="#id1" title="永久链接至标题">#</a></h1>
<section id="id2">
<h2>3.4.1 神经网络的构造<a class="headerlink" href="#id2" title="永久链接至标题">#</a></h2>
<p>PyTorch中神经网络构造一般是基于 Module 类的模型来完成的，它让模型构造更加灵活。</p>
<p>Module 类是 nn 模块里提供的一个模型构造类，是所有神经网络模块的基类，我们可以继承它来定义我们想要的模型。下面继承 Module 类构造多层感知机。这里定义的 MLP 类重载了 Module 类的 <code>__init__</code> 函数和 <code>forward</code> 函数。它们分别用于创建模型参数和定义前向计算。前向计算也即正向传播。</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>

<span class="k">class</span> <span class="nc">MLP</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
  <span class="c1"># 声明带有模型参数的层，这里声明了两个全连接层</span>
  <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    <span class="c1"># 调用MLP父类Module的构造函数来进行必要的初始化。这样在构造实例时还可以指定其他函数参数</span>
    <span class="nb">super</span><span class="p">(</span><span class="n">MLP</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    <span class="bp">self</span><span class="o">.</span><span class="n">hidden</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">784</span><span class="p">,</span> <span class="mi">256</span><span class="p">)</span>
    <span class="bp">self</span><span class="o">.</span><span class="n">act</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()</span>
    <span class="bp">self</span><span class="o">.</span><span class="n">output</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span><span class="mi">10</span><span class="p">)</span>
    
   <span class="c1"># 定义模型的前向计算，即如何根据输入x计算返回所需要的模型输出</span>
  <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
    <span class="n">o</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">act</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">hidden</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
    <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">output</span><span class="p">(</span><span class="n">o</span><span class="p">)</span>   
</pre></div>
</div>
<p>以上的 MLP 类中无须定义反向传播函数。系统将通过自动求梯度而自动生成反向传播所需的 backward 函数。</p>
<p>我们可以实例化 MLP 类得到模型变量 net 。下面的代码初始化 net 并传入输入数据 X 做一次前向计算。其中， net(X) 会调用 MLP 继承自 Module 类的 <code>__call__</code> 函数，这个函数将调用 MLP 类定义的 forward 函数来完成前向计算。</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">X</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span><span class="mi">784</span><span class="p">)</span>
<span class="n">net</span> <span class="o">=</span> <span class="n">MLP</span><span class="p">()</span>
<span class="nb">print</span><span class="p">(</span><span class="n">net</span><span class="p">)</span>
<span class="n">net</span><span class="p">(</span><span class="n">X</span><span class="p">)</span>
</pre></div>
</div>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">MLP</span><span class="p">(</span>
  <span class="p">(</span><span class="n">hidden</span><span class="p">):</span> <span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="o">=</span><span class="mi">784</span><span class="p">,</span> <span class="n">out_features</span><span class="o">=</span><span class="mi">256</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  <span class="p">(</span><span class="n">act</span><span class="p">):</span> <span class="n">ReLU</span><span class="p">()</span>
  <span class="p">(</span><span class="n">output</span><span class="p">):</span> <span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="o">=</span><span class="mi">256</span><span class="p">,</span> <span class="n">out_features</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="p">)</span>
<span class="n">tensor</span><span class="p">([[</span> <span class="mf">0.0149</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.2641</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.0040</span><span class="p">,</span>  <span class="mf">0.0945</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.1277</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.0092</span><span class="p">,</span>  <span class="mf">0.0343</span><span class="p">,</span>  <span class="mf">0.0627</span><span class="p">,</span>
         <span class="o">-</span><span class="mf">0.1742</span><span class="p">,</span>  <span class="mf">0.1866</span><span class="p">],</span>
        <span class="p">[</span> <span class="mf">0.0738</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.1409</span><span class="p">,</span>  <span class="mf">0.0790</span><span class="p">,</span>  <span class="mf">0.0597</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.1572</span><span class="p">,</span>  <span class="mf">0.0479</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.0519</span><span class="p">,</span>  <span class="mf">0.0211</span><span class="p">,</span>
         <span class="o">-</span><span class="mf">0.1435</span><span class="p">,</span>  <span class="mf">0.1958</span><span class="p">]],</span> <span class="n">grad_fn</span><span class="o">=&lt;</span><span class="n">AddmmBackward</span><span class="o">&gt;</span><span class="p">)</span>
</pre></div>
</div>
<p>注意，这里并没有将 Module 类命名为 Layer (层)或者 Model (模型)之类的名字，这是因为该类是一个可供自由组建的部件。它的子类既可以是一个层(如PyTorch提供的 Linear 类)，又可以是一个模型(如这里定义的 MLP 类)，或者是模型的一个部分。</p>
</section>
<section id="id3">
<h2>3.4.2 神经网络中常见的层<a class="headerlink" href="#id3" title="永久链接至标题">#</a></h2>
<p>深度学习的一个魅力在于神经网络中各式各样的层，例如全连接层、卷积层、池化层与循环层等等。虽然PyTorch提供了大量常用的层，但有时候我们依然希望自定义层。这里我们会介绍如何使用 Module 来自定义层，从而可以被反复调用。</p>
<ul class="simple">
<li><p><strong>不含模型参数的层</strong></p></li>
</ul>
<p>我们先介绍如何定义一个不含模型参数的自定义层。下面构造的 MyLayer 类通过继承 Module 类自定义了一个<strong>将输入减掉均值后输出</strong>的层，并将层的计算定义在了 forward 函数里。这个层里不含模型参数。</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>

<span class="k">class</span> <span class="nc">MyLayer</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">MyLayer</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="k">return</span> <span class="n">x</span> <span class="o">-</span> <span class="n">x</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>  
</pre></div>
</div>
<p>测试，实例化该层，然后做前向计算</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">layer</span> <span class="o">=</span> <span class="n">MyLayer</span><span class="p">()</span>
<span class="n">layer</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">))</span>
</pre></div>
</div>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">tensor</span><span class="p">([</span><span class="o">-</span><span class="mf">2.</span><span class="p">,</span> <span class="o">-</span><span class="mf">1.</span><span class="p">,</span>  <span class="mf">0.</span><span class="p">,</span>  <span class="mf">1.</span><span class="p">,</span>  <span class="mf">2.</span><span class="p">])</span>
</pre></div>
</div>
<ul class="simple">
<li><p><strong>含模型参数的层</strong></p></li>
</ul>
<p>我们还可以自定义含模型参数的自定义层。其中的模型参数可以通过训练学出。</p>
<p>Parameter 类其实是 Tensor 的子类，如果一个 Tensor 是 Parameter ，那么它会自动被添加到模型的参数列表里。所以在自定义含模型参数的层时，我们应该将参数定义成 Parameter ，除了直接定义成 Parameter 类外，还可以使用 ParameterList 和 ParameterDict 分别定义参数的列表和字典。</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">MyListDense</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">MyListDense</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">params</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ParameterList</span><span class="p">([</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">))</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">3</span><span class="p">)])</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">1</span><span class="p">)))</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="p">)):</span>
            <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">mm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="p">[</span><span class="n">i</span><span class="p">])</span>
        <span class="k">return</span> <span class="n">x</span>
<span class="n">net</span> <span class="o">=</span> <span class="n">MyListDense</span><span class="p">()</span>
<span class="nb">print</span><span class="p">(</span><span class="n">net</span><span class="p">)</span>
</pre></div>
</div>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">MyDictDense</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">MyDictDense</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">params</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ParameterDict</span><span class="p">({</span>
                <span class="s1">&#39;linear1&#39;</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">)),</span>
                <span class="s1">&#39;linear2&#39;</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
        <span class="p">})</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="o">.</span><span class="n">update</span><span class="p">({</span><span class="s1">&#39;linear3&#39;</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">2</span><span class="p">))})</span> <span class="c1"># 新增</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">choice</span><span class="o">=</span><span class="s1">&#39;linear1&#39;</span><span class="p">):</span>
        <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">mm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="p">[</span><span class="n">choice</span><span class="p">])</span>

<span class="n">net</span> <span class="o">=</span> <span class="n">MyDictDense</span><span class="p">()</span>
<span class="nb">print</span><span class="p">(</span><span class="n">net</span><span class="p">)</span>
</pre></div>
</div>
<p>下面给出常见的神经网络的一些层，比如卷积层、池化层，以及较为基础的AlexNet，LeNet等。</p>
<ul class="simple">
<li><p><strong>二维卷积层</strong></p></li>
</ul>
<p>二维卷积层将输入和卷积核做互相关运算，并加上一个标量偏差来得到输出。卷积层的模型参数包括了卷积核和标量偏差。在训练模型的时候，通常我们先对卷积核随机初始化，然后不断迭代卷积核和偏差。</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>

<span class="c1"># 卷积运算（二维互相关）</span>
<span class="k">def</span> <span class="nf">corr2d</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">K</span><span class="p">):</span> 
    <span class="n">h</span><span class="p">,</span> <span class="n">w</span> <span class="o">=</span> <span class="n">K</span><span class="o">.</span><span class="n">shape</span>
    <span class="n">X</span><span class="p">,</span> <span class="n">K</span> <span class="o">=</span> <span class="n">X</span><span class="o">.</span><span class="n">float</span><span class="p">(),</span> <span class="n">K</span><span class="o">.</span><span class="n">float</span><span class="p">()</span>
    <span class="n">Y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">X</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">-</span> <span class="n">h</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">X</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="n">w</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>
    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">Y</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]):</span>
        <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">Y</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]):</span>
            <span class="n">Y</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="n">X</span><span class="p">[</span><span class="n">i</span><span class="p">:</span> <span class="n">i</span> <span class="o">+</span> <span class="n">h</span><span class="p">,</span> <span class="n">j</span><span class="p">:</span> <span class="n">j</span> <span class="o">+</span> <span class="n">w</span><span class="p">]</span> <span class="o">*</span> <span class="n">K</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span>
    <span class="k">return</span> <span class="n">Y</span>

<span class="c1"># 二维卷积层</span>
<span class="k">class</span> <span class="nc">Conv2D</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">kernel_size</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">Conv2D</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">kernel_size</span><span class="p">))</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">1</span><span class="p">))</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="k">return</span> <span class="n">corr2d</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span>
</pre></div>
</div>
<p>卷积窗口形状为 <span class="math notranslate nohighlight">\(p \times q\)</span> 的卷积层称为 <span class="math notranslate nohighlight">\(p \times q\)</span>  卷积层。同样，  <span class="math notranslate nohighlight">\(p \times q\)</span> 卷积或 <span class="math notranslate nohighlight">\(p \times q\)</span> 卷积核说明卷积核的高和宽分别为 <span class="math notranslate nohighlight">\(p\)</span> 和 <span class="math notranslate nohighlight">\(q\)</span>。</p>
<p>填充(padding)是指在输入高和宽的两侧填充元素(通常是0元素)。</p>
<p>下面的例子里我们创建一个高和宽为3的二维卷积层，然后设输入高和宽两侧的填充数分别为1。给定一个高和宽为8的输入，我们发现输出的高和宽也是8。</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>

<span class="c1"># 定义一个函数来计算卷积层。它对输入和输出做相应的升维和降维</span>
<span class="k">def</span> <span class="nf">comp_conv2d</span><span class="p">(</span><span class="n">conv2d</span><span class="p">,</span> <span class="n">X</span><span class="p">):</span>
    <span class="c1"># (1, 1)代表批量大小和通道数</span>
    <span class="n">X</span> <span class="o">=</span> <span class="n">X</span><span class="o">.</span><span class="n">view</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">+</span> <span class="n">X</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
    <span class="n">Y</span> <span class="o">=</span> <span class="n">conv2d</span><span class="p">(</span><span class="n">X</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">Y</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">Y</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">:])</span> <span class="c1"># 排除不关心的前两维:批量和通道</span>


<span class="c1"># 注意这里是两侧分别填充1行或列，所以在两侧一共填充2行或列</span>
<span class="n">conv2d</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span><span class="n">padding</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>

<span class="n">X</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">8</span><span class="p">,</span> <span class="mi">8</span><span class="p">)</span>
<span class="n">comp_conv2d</span><span class="p">(</span><span class="n">conv2d</span><span class="p">,</span> <span class="n">X</span><span class="p">)</span><span class="o">.</span><span class="n">shape</span>
</pre></div>
</div>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">torch</span><span class="o">.</span><span class="n">Size</span><span class="p">([</span><span class="mi">8</span><span class="p">,</span> <span class="mi">8</span><span class="p">])</span>
</pre></div>
</div>
<p>当卷积核的高和宽不同时，我们也可以通过设置高和宽上不同的填充数使输出和输入具有相同的高和宽。</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># 使用高为5、宽为3的卷积核。在高和宽两侧的填充数分别为2和1</span>
<span class="n">conv2d</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
<span class="n">comp_conv2d</span><span class="p">(</span><span class="n">conv2d</span><span class="p">,</span> <span class="n">X</span><span class="p">)</span><span class="o">.</span><span class="n">shape</span>
</pre></div>
</div>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">torch</span><span class="o">.</span><span class="n">Size</span><span class="p">([</span><span class="mi">8</span><span class="p">,</span> <span class="mi">8</span><span class="p">])</span>
</pre></div>
</div>
<p>在二维互相关运算中，卷积窗口从输入数组的最左上方开始，按从左往右、从上往下的顺序，依次在输入数组上滑动。我们将每次滑动的行数和列数称为步幅(stride)。</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">conv2d</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">5</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">stride</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">))</span>
<span class="n">comp_conv2d</span><span class="p">(</span><span class="n">conv2d</span><span class="p">,</span> <span class="n">X</span><span class="p">)</span><span class="o">.</span><span class="n">shape</span>
</pre></div>
</div>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">torch</span><span class="o">.</span><span class="n">Size</span><span class="p">([</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span>
</pre></div>
</div>
<p>填充可以增加输出的高和宽。这常用来使输出与输入具有相同的高和宽。</p>
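<p>步幅可以减小输出的高和宽，例如输出的高和宽仅为输入的高和宽的 <span class="math notranslate nohighlight">\(1/n\)</span> (<span class="math notranslate nohighlight">\(n\)</span> 为大于1的整数)。</p>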
<ul class="simple">
<li><p><strong>池化层</strong></p></li>
</ul>
<p>池化层每次对输入数据的一个固定形状窗口(又称池化窗口)中的元素计算输出。不同于卷积层里计算输入和核的互相关性，池化层直接计算池化窗口内元素的最大值或者平均值。该运算也分别叫做最大池化或平均池化。在二维最大池化中，池化窗口从输入数组的最左上方开始，按从左往右、从上往下的顺序，依次在输入数组上滑动。当池化窗口滑动到某一位置时，窗口中的输入子数组的最大值即输出数组中相应位置的元素。</p>
<p>下面把池化层的前向计算实现在<code class="docutils literal notranslate"><span class="pre">pool2d</span></code>函数里。</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>

<span class="k">def</span> <span class="nf">pool2d</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">pool_size</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s1">&#39;max&#39;</span><span class="p">):</span>
    <span class="n">p_h</span><span class="p">,</span> <span class="n">p_w</span> <span class="o">=</span> <span class="n">pool_size</span>
    <span class="n">Y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">X</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">-</span> <span class="n">p_h</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">X</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="n">p_w</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>
    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">Y</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]):</span>
        <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">Y</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]):</span>
            <span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s1">&#39;max&#39;</span><span class="p">:</span>
                <span class="n">Y</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">]</span> <span class="o">=</span> <span class="n">X</span><span class="p">[</span><span class="n">i</span><span class="p">:</span> <span class="n">i</span> <span class="o">+</span> <span class="n">p_h</span><span class="p">,</span> <span class="n">j</span><span class="p">:</span> <span class="n">j</span> <span class="o">+</span> <span class="n">p_w</span><span class="p">]</span><span class="o">.</span><span class="n">max</span><span class="p">()</span>
            <span class="k">elif</span> <span class="n">mode</span> <span class="o">==</span> <span class="s1">&#39;avg&#39;</span><span class="p">:</span>
                <span class="n">Y</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">]</span> <span class="o">=</span> <span class="n">X</span><span class="p">[</span><span class="n">i</span><span class="p">:</span> <span class="n">i</span> <span class="o">+</span> <span class="n">p_h</span><span class="p">,</span> <span class="n">j</span><span class="p">:</span> <span class="n">j</span> <span class="o">+</span> <span class="n">p_w</span><span class="p">]</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>
    <span class="k">return</span> <span class="n">Y</span>
</pre></div>
</div>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">X</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">],</span> <span class="p">[</span><span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">],</span> <span class="p">[</span><span class="mi">6</span><span class="p">,</span> <span class="mi">7</span><span class="p">,</span> <span class="mi">8</span><span class="p">]],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span>
<span class="n">pool2d</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">))</span>
</pre></div>
</div>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">tensor</span><span class="p">([[</span><span class="mf">4.</span><span class="p">,</span> <span class="mf">5.</span><span class="p">],</span>
	<span class="p">[</span><span class="mf">7.</span><span class="p">,</span> <span class="mf">8.</span><span class="p">]])</span>
</pre></div>
</div>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">pool2d</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span> <span class="s1">&#39;avg&#39;</span><span class="p">)</span>
</pre></div>
</div>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">tensor</span><span class="p">([[</span><span class="mf">2.</span><span class="p">,</span> <span class="mf">3.</span><span class="p">],</span>
	<span class="p">[</span><span class="mf">5.</span><span class="p">,</span> <span class="mf">6.</span><span class="p">]])</span>
</pre></div>
</div>
<p>我们可以使用<code class="docutils literal notranslate"><span class="pre">torch.nn</span></code>包来构建神经网络。我们已经介绍了<code class="docutils literal notranslate"><span class="pre">autograd</span></code>包，<code class="docutils literal notranslate"><span class="pre">nn</span></code>包则依赖于<code class="docutils literal notranslate"><span class="pre">autograd</span></code>包来定义模型并对它们求导。一个<code class="docutils literal notranslate"><span class="pre">nn.Module</span></code>包含各个层和一个<code class="docutils literal notranslate"><span class="pre">forward(input)</span></code>方法，该方法返回<code class="docutils literal notranslate"><span class="pre">output</span></code>。</p>
</section>
<section id="id4">
<h2>3.4.3 模型示例<a class="headerlink" href="#id4" title="永久链接至标题">#</a></h2>
<ul class="simple">
<li><p><strong>LeNet</strong></p></li>
</ul>
<p><img alt="LeNet 网络结构示意图" src="../_images/3.4.1.png" /></p>
<p>这是一个简单的前馈神经网络（feed-forward network）（LeNet）。它接受一个输入，然后将它送入下一层，一层接一层地传递，最后给出输出。</p>
<p>一个神经网络的典型训练过程如下：</p>
<ol class="simple">
<li><p>定义包含一些可学习参数（或者叫权重）的神经网络</p></li>
<li><p>在输入数据集上迭代</p></li>
<li><p>通过网络处理输入</p></li>
<li><p>计算 loss（输出和正确答案的距离）</p></li>
<li><p>将梯度反向传播给网络的参数</p></li>
<li><p>更新网络的权重，一般使用一个简单的规则：<code class="docutils literal notranslate"><span class="pre">weight</span> <span class="pre">=</span> <span class="pre">weight</span> <span class="pre">-</span> <span class="pre">learning_rate</span> <span class="pre">*</span> <span class="pre">gradient</span></code></p></li>
</ol>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
<span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="nn">F</span>


<span class="k">class</span> <span class="nc">Net</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>

    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">Net</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="c1"># 输入图像channel：1；输出channel：6；5x5卷积核</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">conv1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">5</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">conv2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">6</span><span class="p">,</span> <span class="mi">16</span><span class="p">,</span> <span class="mi">5</span><span class="p">)</span>
        <span class="c1"># an affine operation: y = Wx + b</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">fc1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">16</span> <span class="o">*</span> <span class="mi">5</span> <span class="o">*</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">120</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">fc2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">120</span><span class="p">,</span> <span class="mi">84</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">fc3</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">84</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="c1"># 2x2 Max pooling</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">max_pool2d</span><span class="p">(</span><span class="n">F</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">conv1</span><span class="p">(</span><span class="n">x</span><span class="p">)),</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">))</span>
        <span class="c1"># 如果是方阵,则可以只使用一个数字进行定义</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">max_pool2d</span><span class="p">(</span><span class="n">F</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">conv2</span><span class="p">(</span><span class="n">x</span><span class="p">)),</span> <span class="mi">2</span><span class="p">)</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_flat_features</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fc1</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fc2</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc3</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">x</span>

    <span class="k">def</span> <span class="nf">num_flat_features</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="n">size</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="mi">1</span><span class="p">:]</span>  <span class="c1"># 除去批处理维度的其他所有维度</span>
        <span class="n">num_features</span> <span class="o">=</span> <span class="mi">1</span>
        <span class="k">for</span> <span class="n">s</span> <span class="ow">in</span> <span class="n">size</span><span class="p">:</span>
            <span class="n">num_features</span> <span class="o">*=</span> <span class="n">s</span>
        <span class="k">return</span> <span class="n">num_features</span>


<span class="n">net</span> <span class="o">=</span> <span class="n">Net</span><span class="p">()</span>
<span class="nb">print</span><span class="p">(</span><span class="n">net</span><span class="p">)</span>
</pre></div>
</div>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">Net</span><span class="p">(</span>
  <span class="p">(</span><span class="n">conv1</span><span class="p">):</span> <span class="n">Conv2d</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">),</span> <span class="n">stride</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
  <span class="p">(</span><span class="n">conv2</span><span class="p">):</span> <span class="n">Conv2d</span><span class="p">(</span><span class="mi">6</span><span class="p">,</span> <span class="mi">16</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">),</span> <span class="n">stride</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
  <span class="p">(</span><span class="n">fc1</span><span class="p">):</span> <span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="o">=</span><span class="mi">400</span><span class="p">,</span> <span class="n">out_features</span><span class="o">=</span><span class="mi">120</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  <span class="p">(</span><span class="n">fc2</span><span class="p">):</span> <span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="o">=</span><span class="mi">120</span><span class="p">,</span> <span class="n">out_features</span><span class="o">=</span><span class="mi">84</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  <span class="p">(</span><span class="n">fc3</span><span class="p">):</span> <span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="o">=</span><span class="mi">84</span><span class="p">,</span> <span class="n">out_features</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="p">)</span>
</pre></div>
</div>
<p>我们只需要定义 <code class="docutils literal notranslate"><span class="pre">forward</span></code> 函数，<code class="docutils literal notranslate"><span class="pre">backward</span></code>函数会在使用<code class="docutils literal notranslate"><span class="pre">autograd</span></code>时自动定义，<code class="docutils literal notranslate"><span class="pre">backward</span></code>函数用来计算导数。我们可以在 <code class="docutils literal notranslate"><span class="pre">forward</span></code> 函数中使用任何针对张量的操作和计算。</p>
<p>一个模型的可学习参数可以通过<code class="docutils literal notranslate"><span class="pre">net.parameters()</span></code>返回。</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">params</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">net</span><span class="o">.</span><span class="n">parameters</span><span class="p">())</span>
<span class="nb">print</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">params</span><span class="p">))</span>
<span class="nb">print</span><span class="p">(</span><span class="n">params</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">size</span><span class="p">())</span>  <span class="c1"># conv1的权重</span>
</pre></div>
</div>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="mi">10</span>
<span class="n">torch</span><span class="o">.</span><span class="n">Size</span><span class="p">([</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">])</span>
</pre></div>
</div>
<p>让我们尝试一个随机的 32x32 的输入。注意：这个网络（LeNet）的期待输入是 32x32 的张量。如果使用 MNIST 数据集来训练这个网络，要把图片大小重新调整到 32x32。</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="nb">input</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">)</span>
<span class="n">out</span> <span class="o">=</span> <span class="n">net</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">out</span><span class="p">)</span>
</pre></div>
</div>
<p>清零所有参数的梯度缓存，然后进行随机梯度的反向传播：</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">net</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span>
<span class="n">out</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span>
</pre></div>
</div>
<p>注意：<code class="docutils literal notranslate"><span class="pre">torch.nn</span></code>只支持小批量处理（mini-batches）。整个 <code class="docutils literal notranslate"><span class="pre">torch.nn</span></code> 包只支持小批量样本的输入，不支持单个样本的输入。比如，<code class="docutils literal notranslate"><span class="pre">nn.Conv2d</span></code> 接受一个4维的张量，即<code class="docutils literal notranslate"><span class="pre">nSamples</span> <span class="pre">x</span> <span class="pre">nChannels</span> <span class="pre">x</span> <span class="pre">Height</span> <span class="pre">x</span> <span class="pre">Width</span></code>。如果是一个单独的样本，只需要使用<code class="docutils literal notranslate"><span class="pre">input.unsqueeze(0)</span></code> 来添加一个“假的”批大小维度。</p>
<ul class="simple">
<li><p><code class="docutils literal notranslate"><span class="pre">torch.Tensor</span></code> - 一个多维数组，支持诸如<code class="docutils literal notranslate"><span class="pre">backward()</span></code>等的自动求导操作，同时也保存了张量的梯度。</p></li>
<li><p><code class="docutils literal notranslate"><span class="pre">nn.Module</span></code> - 神经网络模块。是一种方便封装参数的方式，具有将参数移动到GPU、导出、加载等功能。</p></li>
<li><p><code class="docutils literal notranslate"><span class="pre">nn.Parameter</span></code> - 张量的一种，当它作为一个属性分配给一个<code class="docutils literal notranslate"><span class="pre">Module</span></code>时，它会被自动注册为一个参数。</p></li>
<li><p><code class="docutils literal notranslate"><span class="pre">autograd.Function</span></code> - 实现了自动求导前向和反向传播的定义，每个<code class="docutils literal notranslate"><span class="pre">Tensor</span></code>至少创建一个<code class="docutils literal notranslate"><span class="pre">Function</span></code>节点，该节点连接到创建<code class="docutils literal notranslate"><span class="pre">Tensor</span></code>的函数并对其历史进行编码。</p></li>
</ul>
<p>下面再介绍一个比较基础的案例 AlexNet。</p>
<ul class="simple">
<li><p><strong>AlexNet</strong></p></li>
</ul>
<p><img alt="3.4.2" src="../_images/3.4.2.png" /></p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">AlexNet</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">AlexNet</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span>
            <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">96</span><span class="p">,</span> <span class="mi">11</span><span class="p">,</span> <span class="mi">4</span><span class="p">),</span> <span class="c1"># in_channels, out_channels, kernel_size, stride</span>
            <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(),</span>
            <span class="n">nn</span><span class="o">.</span><span class="n">MaxPool2d</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span> <span class="c1"># kernel_size, stride</span>
            <span class="c1"># 减小卷积窗口，使用填充为2来使得输入与输出的高和宽一致，且增大输出通道数</span>
            <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">96</span><span class="p">,</span> <span class="mi">256</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span>
            <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(),</span>
            <span class="n">nn</span><span class="o">.</span><span class="n">MaxPool2d</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span>
            <span class="c1"># 连续3个卷积层，且使用更小的卷积窗口。除了最后的卷积层外，进一步增大了输出通道数。</span>
            <span class="c1"># 前两个卷积层后不使用池化层来减小输入的高和宽</span>
            <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="mi">384</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
            <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(),</span>
            <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">384</span><span class="p">,</span> <span class="mi">384</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
            <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(),</span>
            <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">384</span><span class="p">,</span> <span class="mi">256</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
            <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(),</span>
            <span class="n">nn</span><span class="o">.</span><span class="n">MaxPool2d</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
        <span class="p">)</span>
         <span class="c1"># 这里全连接层的输出个数比LeNet中的大数倍。使用丢弃层来缓解过拟合</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">fc</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span>
            <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">256</span><span class="o">*</span><span class="mi">5</span><span class="o">*</span><span class="mi">5</span><span class="p">,</span> <span class="mi">4096</span><span class="p">),</span>
            <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(),</span>
            <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.5</span><span class="p">),</span>
            <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">4096</span><span class="p">,</span> <span class="mi">4096</span><span class="p">),</span>
            <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(),</span>
            <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.5</span><span class="p">),</span>
            <span class="c1"># 输出层。由于这里使用Fashion-MNIST，所以用类别数为10，而非论文中的1000</span>
            <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">4096</span><span class="p">,</span> <span class="mi">10</span><span class="p">),</span>
        <span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">img</span><span class="p">):</span>
        <span class="n">feature</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv</span><span class="p">(</span><span class="n">img</span><span class="p">)</span>
        <span class="n">output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc</span><span class="p">(</span><span class="n">feature</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span>
        <span class="k">return</span> <span class="n">output</span>
</pre></div>
</div>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">net</span> <span class="o">=</span> <span class="n">AlexNet</span><span class="p">()</span>
<span class="nb">print</span><span class="p">(</span><span class="n">net</span><span class="p">)</span>
</pre></div>
</div>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">AlexNet</span><span class="p">(</span>
  <span class="p">(</span><span class="n">conv</span><span class="p">):</span> <span class="n">Sequential</span><span class="p">(</span>
    <span class="p">(</span><span class="mi">0</span><span class="p">):</span> <span class="n">Conv2d</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">96</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">11</span><span class="p">,</span> <span class="mi">11</span><span class="p">),</span> <span class="n">stride</span><span class="o">=</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">))</span>
    <span class="p">(</span><span class="mi">1</span><span class="p">):</span> <span class="n">ReLU</span><span class="p">()</span>
    <span class="p">(</span><span class="mi">2</span><span class="p">):</span> <span class="n">MaxPool2d</span><span class="p">(</span><span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">dilation</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">ceil_mode</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
    <span class="p">(</span><span class="mi">3</span><span class="p">):</span> <span class="n">Conv2d</span><span class="p">(</span><span class="mi">96</span><span class="p">,</span> <span class="mi">256</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">),</span> <span class="n">stride</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">))</span>
    <span class="p">(</span><span class="mi">4</span><span class="p">):</span> <span class="n">ReLU</span><span class="p">()</span>
    <span class="p">(</span><span class="mi">5</span><span class="p">):</span> <span class="n">MaxPool2d</span><span class="p">(</span><span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">dilation</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">ceil_mode</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
    <span class="p">(</span><span class="mi">6</span><span class="p">):</span> <span class="n">Conv2d</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="mi">384</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">stride</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
    <span class="p">(</span><span class="mi">7</span><span class="p">):</span> <span class="n">ReLU</span><span class="p">()</span>
    <span class="p">(</span><span class="mi">8</span><span class="p">):</span> <span class="n">Conv2d</span><span class="p">(</span><span class="mi">384</span><span class="p">,</span> <span class="mi">384</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">stride</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
    <span class="p">(</span><span class="mi">9</span><span class="p">):</span> <span class="n">ReLU</span><span class="p">()</span>
    <span class="p">(</span><span class="mi">10</span><span class="p">):</span> <span class="n">Conv2d</span><span class="p">(</span><span class="mi">384</span><span class="p">,</span> <span class="mi">256</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">stride</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
    <span class="p">(</span><span class="mi">11</span><span class="p">):</span> <span class="n">ReLU</span><span class="p">()</span>
    <span class="p">(</span><span class="mi">12</span><span class="p">):</span> <span class="n">MaxPool2d</span><span class="p">(</span><span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">dilation</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">ceil_mode</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
  <span class="p">)</span>
  <span class="p">(</span><span class="n">fc</span><span class="p">):</span> <span class="n">Sequential</span><span class="p">(</span>
    <span class="p">(</span><span class="mi">0</span><span class="p">):</span> <span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="o">=</span><span class="mi">6400</span><span class="p">,</span> <span class="n">out_features</span><span class="o">=</span><span class="mi">4096</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    <span class="p">(</span><span class="mi">1</span><span class="p">):</span> <span class="n">ReLU</span><span class="p">()</span>
    <span class="p">(</span><span class="mi">2</span><span class="p">):</span> <span class="n">Dropout</span><span class="p">(</span><span class="n">p</span><span class="o">=</span><span class="mf">0.5</span><span class="p">)</span>
    <span class="p">(</span><span class="mi">3</span><span class="p">):</span> <span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="o">=</span><span class="mi">4096</span><span class="p">,</span> <span class="n">out_features</span><span class="o">=</span><span class="mi">4096</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    <span class="p">(</span><span class="mi">4</span><span class="p">):</span> <span class="n">ReLU</span><span class="p">()</span>
    <span class="p">(</span><span class="mi">5</span><span class="p">):</span> <span class="n">Dropout</span><span class="p">(</span><span class="n">p</span><span class="o">=</span><span class="mf">0.5</span><span class="p">)</span>
    <span class="p">(</span><span class="mi">6</span><span class="p">):</span> <span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="o">=</span><span class="mi">4096</span><span class="p">,</span> <span class="n">out_features</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  <span class="p">)</span>
<span class="p">)</span>
</pre></div>
</div>
</section>
</section>


              </div>
              
            </main>
            <footer class="footer-article noprint">
                
    <!-- Previous / next buttons -->
<div class='prev-next-area'>
    <a class='left-prev' id="prev-link" href="3.3%20%E6%95%B0%E6%8D%AE%E8%AF%BB%E5%85%A5.html" title="上一页">
        <i class="fas fa-angle-left"></i>
        <div class="prev-next-info">
            <p class="prev-next-subtitle">上一页</p>
            <p class="prev-next-title">3.3 数据读入</p>
        </div>
    </a>
    <a class='right-next' id="next-link" href="3.5%20%E6%A8%A1%E5%9E%8B%E5%88%9D%E5%A7%8B%E5%8C%96.html" title="下一页">
    <div class="prev-next-info">
        <p class="prev-next-subtitle">下一页</p>
        <p class="prev-next-title">3.5 模型初始化</p>
    </div>
    <i class="fas fa-angle-right"></i>
    </a>
</div>
            </footer>
        </div>
    </div>
    <div class="footer-content row">
        <footer class="col footer"><p>
  
    By ZhikangNiu<br/>
  
      &copy; Copyright 2022, ZhikangNiu.<br/>
</p>
        </footer>
    </div>
    
</div>


      </div>
    </div>
  
  <!-- Scripts loaded after <body> so the DOM is not blocked -->
  <script src="../_static/scripts/pydata-sphinx-theme.js?digest=1999514e3f237ded88cf"></script>


  </body>
</html>